{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4324a245",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys  \n",
    "import json\n",
    "import torch \n",
    "import argparse\n",
    "import numpy as np\n",
    "from PIL import Image \n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset\n",
    "from utils import model_gen, extract_answer\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer  \n",
    "\n",
    "import openai\n",
    "# Take the OpenAI key from the OPENAI_API_KEY environment variable so no secret is\n",
    "# stored in the notebook; an unset variable yields the empty string.\n",
    "openai.api_key = os.environ.get('OPENAI_API_KEY', '')\n",
    "\n",
    "# MathVista benchmark; the later cells iterate its `testmini` split.\n",
    "dataset = load_dataset(\"AI4Math/MathVista\")\n",
    "\n",
    "# NOTE: trust_remote_code=True executes Python code shipped with the checkpoint repository.\n",
    "ckpt_path = 'internlm/internlm-xcomposer2-vl-7b'\n",
    "tokenizer = AutoTokenizer.from_pretrained(ckpt_path, trust_remote_code=True)\n",
    "# Load on the GPU, switch to eval mode and cast the weights to bfloat16.\n",
    "model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=\"cuda\", trust_remote_code=True).eval().cuda().to(torch.bfloat16)\n",
    "# The tokenizer is attached to the model object (presumably read by the remote model code / model_gen -- confirm).\n",
    "model.tokenizer = tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e77a6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the model once per testmini sample; `final` maps each sample's pid to its response.\n",
    "final = {}\n",
    "for sample in tqdm(dataset[\"testmini\"]):\n",
    "    question = sample['query']\n",
    "    # Wrap the query in the user/assistant markers (presumably the checkpoint's chat format -- confirm against the model card).\n",
    "    prompt = f\"[UNUSED_TOKEN_146]user\\n{question}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]assistant\\n\"\n",
    "    picture = sample['decoded_image']\n",
    "    # Inference only: CUDA autocast is active and gradient tracking is disabled.\n",
    "    with torch.cuda.amp.autocast(), torch.no_grad():\n",
    "        answer = model_gen(model, prompt, picture)\n",
    "    final[sample['pid']] = answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ccf7ad3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build `full_out`: every testmini record, keyed by pid, extended with the extracted answer.\n",
    "full_out = {}\n",
    "for sample in tqdm(dataset[\"testmini\"]):\n",
    "    # extract_answer receives the raw response together with the still-unmodified record.\n",
    "    extracted = extract_answer(final[sample['pid']], sample)\n",
    "    # Store the extraction and blank the decoded image (presumably because the image\n",
    "    # object cannot be serialised by json.dump later -- confirm).\n",
    "    sample.update({'extraction': extracted, 'decoded_image': ''})\n",
    "    full_out[sample['pid']] = sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a2cafb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the annotated records to disk. The with-block flushes and closes the file\n",
    "# deterministically, even if json.dump raises part-way through.\n",
    "with open('MathVista_Testmini_InternLM_XComposer_VL.json', 'w') as out_file:\n",
    "    json.dump(full_out, out_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
